package tensorflow

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"

	pb "github.com/wamuir/graft/tensorflow/core/protobuf/for_core_protos_go_proto"
	"google.golang.org/protobuf/proto"

	"github.com/photoprism/photoprism/pkg/clean"
)

// ExpectedChannels defines the expected number of channels of an input image.
// The value is fixed because a de facto standard seems to have been
// defined for input images as "what decodeImage returns".
const ExpectedChannels = 3

// Interval describes the range of allowed values.
//
// Start and End bound the range. Mean and StdDev are optional overrides:
// when Mean is set, Size returns it instead of End - Start, and when StdDev
// is set, Offset returns it instead of Start.
//
// NOTE(review): the pairing of Mean with Size and of StdDev with Offset
// follows the existing doc comments and Offset implementation; confirm it
// against the code that consumes Size and Offset.
type Interval struct {
	Start  float32  `yaml:"Start,omitempty" json:"start,omitempty"`
	End    float32  `yaml:"End,omitempty" json:"end,omitempty"`
	Mean   *float32 `yaml:"Mean,omitempty" json:"mean,omitempty"`
	StdDev *float32 `yaml:"StdDev,omitempty" json:"stdDev,omitempty"`
}

// Size returns the size/mean of the interval: the explicit Mean if one
// is configured, otherwise the distance between Start and End.
func (i Interval) Size() float32 {
	if i.Mean != nil {
		return *i.Mean
	}

	return i.End - i.Start
}

// Offset returns the offset of the interval: the explicit StdDev if one
// is configured, otherwise Start.
func (i Interval) Offset() float32 {
	if i.StdDev != nil {
		return *i.StdDev
	}

	return i.Start
}

// StandardInterval returns the standard interval, i.e.
// the range of values returned by decodeImage in [0, 1].
// Every call returns a newly allocated value without Mean or StdDev.
func StandardInterval() *Interval {
	return &Interval{
		Start: 0.0,
		End:   1.0,
	}
}

// ResizeOperation identifies how an image is fitted to the model input size.
// It is encoded by name in JSON and YAML to keep configuration files user-friendly.
type ResizeOperation int

const (
	// UndefinedResizeOperation indicates that no resize strategy was specified.
	UndefinedResizeOperation ResizeOperation = iota
	// ResizeBreakAspectRatio resizes without preserving aspect ratio.
	ResizeBreakAspectRatio
	// CenterCrop crops the center region after resizing to fill the target size.
	CenterCrop
	// Padding resizes while preserving aspect ratio and pads the rest.
	Padding
)

// String returns the name of the operation, or "Unknown" for values
// outside the declared set.
func (o ResizeOperation) String() string {
	names := [...]string{
		UndefinedResizeOperation: "Undefined",
		ResizeBreakAspectRatio:   "ResizeBreakAspectRatio",
		CenterCrop:               "CenterCrop",
		Padding:                  "Padding",
	}

	if o < 0 || int(o) >= len(names) {
		return "Unknown"
	}

	return names[o]
}

// NewResizeOperation looks up the ResizeOperation whose name equals s.
// Names are case-sensitive; an unrecognized name yields
// UndefinedResizeOperation together with an error.
func NewResizeOperation(s string) (ResizeOperation, error) {
	candidates := []ResizeOperation{
		UndefinedResizeOperation,
		ResizeBreakAspectRatio,
		CenterCrop,
		Padding,
	}

	for _, op := range candidates {
		if op.String() == s {
			return op, nil
		}
	}

	return UndefinedResizeOperation, fmt.Errorf("invalid operation %s", s)
}

// MarshalJSON writes the operation as a JSON string holding its name.
func (o ResizeOperation) MarshalJSON() ([]byte, error) {
	name := o.String()
	return json.Marshal(name)
}

// UnmarshalJSON reads a JSON string and stores the operation it names.
// The receiver is left untouched when decoding or parsing fails.
func (o *ResizeOperation) UnmarshalJSON(data []byte) error {
	var name string

	err := json.Unmarshal(data, &name)
	if err != nil {
		return err
	}

	parsed, err := NewResizeOperation(name)
	if err == nil {
		*o = parsed
	}

	return err
}

// MarshalYAML returns the operation's name as the value to emit.
func (o ResizeOperation) MarshalYAML() (any, error) {
	name := o.String()
	return name, nil
}

// UnmarshalYAML decodes a YAML string through the supplied callback and
// stores the operation it names. The receiver is left untouched on failure.
func (o *ResizeOperation) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var name string

	err := unmarshal(&name)
	if err != nil {
		return err
	}

	parsed, err := NewResizeOperation(name)
	if err == nil {
		*o = parsed
	}

	return err
}

// ColorChannelOrder represents the order of the model's input vectors.
// JSON and YAML functions are provided to make the configuration files user-friendly.
//
// A value is a three-digit decimal number in which each digit names a
// channel (1 = R, 2 = G, 3 = B) and the digit's position, read from left
// to right, is that channel's zero-based index.
type ColorChannelOrder int

// Every constant carries the ColorChannelOrder type explicitly so that it
// has the type's methods (String, MarshalJSON, ...) when used on its own.
const (
	// UndefinedOrder leaves channel order unspecified, defaulting to RGB.
	UndefinedOrder ColorChannelOrder = 0
	// RGB represents Red-Green-Blue channel order.
	RGB ColorChannelOrder = 123
	// RBG represents Red-Blue-Green channel order.
	RBG ColorChannelOrder = 132
	// GRB represents Green-Red-Blue channel order.
	GRB ColorChannelOrder = 213
	// GBR represents Green-Blue-Red channel order.
	GBR ColorChannelOrder = 231
	// BRG represents Blue-Red-Green channel order.
	BRG ColorChannelOrder = 312
	// BGR represents Blue-Green-Red channel order.
	BGR ColorChannelOrder = 321
)

// Indices returns the zero-based indices of the R, G, and B channels.
// UndefinedOrder is treated as RGB. A channel whose digit does not occur
// in the value keeps index 0.
func (o ColorChannelOrder) Indices() (r, g, b int) {
	digits := int(o)

	if digits == 0 {
		digits = int(RGB)
	}

	// Digits are consumed from the right, so the first one read belongs
	// to the last position (index 2).
	for pos := 2; digits > 0 && pos >= 0; pos-- {
		switch digits % 10 {
		case 1:
			r = pos
		case 2:
			g = pos
		case 3:
			b = pos
		}

		digits /= 10
	}

	return r, g, b
}

// String returns the channel letters in order, e.g. "BGR". UndefinedOrder
// is rendered as "RGB" and any digit other than 1, 2, or 3 as "?".
func (o ColorChannelOrder) String() string {
	digits := int(o)

	if digits == 0 {
		digits = int(RGB)
	}

	name := ""

	// Digits come out least significant first, so each letter is prepended.
	for ; digits > 0; digits /= 10 {
		letter := "?"

		switch digits % 10 {
		case 1:
			letter = "R"
		case 2:
			letter = "G"
		case 3:
			letter = "B"
		}

		name = letter + name
	}

	return name
}

// NewColorChannelOrder parses a string (e.g., "RGB") into a ColorChannelOrder.
// The string must consist of the upper-case letters R, G, and B, each used
// exactly once; anything else yields UndefinedOrder and an error.
func NewColorChannelOrder(val string) (ColorChannelOrder, error) {
	if len(val) != 3 {
		return UndefinedOrder, fmt.Errorf("invalid length, expected 3")
	}

	var seen [4]bool // indexed by channel digit 1..3

	result := 0

	for _, c := range val {
		digit := 0

		switch c {
		case 'R':
			digit = 1
		case 'G':
			digit = 2
		case 'B':
			digit = 3
		}

		if digit == 0 {
			return UndefinedOrder, fmt.Errorf("invalid val %c", c)
		}

		// A repeated letter such as "RRG" is not a permutation of the
		// channels and would leave another channel without a position.
		if seen[digit] {
			return UndefinedOrder, fmt.Errorf("duplicate channel %c", c)
		}

		seen[digit] = true
		result = result*10 + digit
	}

	return ColorChannelOrder(result), nil
}

// MarshalJSON encodes the channel order as its string name.
func (o ColorChannelOrder) MarshalJSON() ([]byte, error) {
	return json.Marshal(o.String())
}

// UnmarshalJSON decodes a channel order from its string representation.
// The receiver is left unchanged when decoding or parsing fails.
func (o *ColorChannelOrder) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}

	val, err := NewColorChannelOrder(s)
	if err != nil {
		return err
	}

	*o = val

	return nil
}

// MarshalYAML encodes the channel order for YAML output.
func (o ColorChannelOrder) MarshalYAML() (any, error) {
	return o.String(), nil
}

// UnmarshalYAML decodes the channel order from YAML input.
// The receiver is left unchanged when decoding or parsing fails.
func (o *ColorChannelOrder) UnmarshalYAML(unmarshal func(any) error) error {
	var s string
	if err := unmarshal(&s); err != nil {
		return err
	}

	val, err := NewColorChannelOrder(s)
	if err != nil {
		return err
	}

	*o = val

	return nil
}

// ShapeComponent describes a single dimension of a model input shape.
// Usually this shape is (batch, resolution, resolution, channels) but sometimes it is not.
type ShapeComponent string

// Every constant carries the ShapeComponent type explicitly; without it
// only the first entry of the block would be typed and the others would be
// plain untyped string constants.
const (
	// ShapeBatch represents the batch dimension.
	ShapeBatch ShapeComponent = "Batch"
	// ShapeWidth represents the width dimension.
	ShapeWidth ShapeComponent = "Width"
	// ShapeHeight represents the height dimension.
	ShapeHeight ShapeComponent = "Height"
	// ShapeColor represents the color/channel dimension.
	ShapeColor ShapeComponent = "Color"
)

// DefaultPhotoInputShape returns the standard BHWC input shape, i.e.
// batch, height, width, color. Each call returns a new slice.
func DefaultPhotoInputShape() []ShapeComponent {
	return []ShapeComponent{
		ShapeBatch,
		ShapeHeight,
		ShapeWidth,
		ShapeColor,
	}
}

// PhotoInput represents an input description for a photo input for a model.
//
// Intervals holds the value range per channel; GetInterval applies a single
// entry to every channel and falls back to [0, 1] when none is defined.
// Height and Width are the input dimensions in pixels; IsDynamic reports
// true when both are -1. OutputIndex is serialized under the "Index"/"index"
// key and ColorChannelOrder under the JSON key "inputOrder". Shape lists the
// order of the input dimensions.
type PhotoInput struct {
	Name              string            `yaml:"Name,omitempty" json:"name,omitempty"`
	Intervals         []Interval        `yaml:"Intervals,omitempty" json:"intervals,omitempty"`
	ResizeOperation   ResizeOperation   `yaml:"ResizeOperation,omitempty" json:"resizeOperation,omitempty"`
	ColorChannelOrder ColorChannelOrder `yaml:"ColorChannelOrder,omitempty" json:"inputOrder,omitempty"`
	OutputIndex       int               `yaml:"Index,omitempty" json:"index,omitempty"`
	Height            int64             `yaml:"Height,omitempty" json:"height,omitempty"`
	Width             int64             `yaml:"Width,omitempty" json:"width,omitempty"`
	Shape             []ShapeComponent  `yaml:"Shape,omitempty" json:"shape,omitempty"`
}

// IsDynamic reports whether both Height and Width carry the marker
// value -1, meaning the image dimensions are not defined and the model
// accepts any size.
func (p PhotoInput) IsDynamic() bool {
	const undefined = -1

	if p.Height != undefined {
		return false
	}

	return p.Width == undefined
}

// Resolution returns the input image resolution: Width when it is
// positive, otherwise Height (whatever its value).
func (p PhotoInput) Resolution() int {
	side := p.Height

	if p.Width > 0 {
		side = p.Width
	}

	return int(side)
}

// SetResolution stores the given resolution in pixels (max width and
// height) as both the input image Height and Width.
func (p *PhotoInput) SetResolution(resolution int) {
	side := int64(resolution)
	p.Height, p.Width = side, side
}

// GetInterval returns the interval for the given channel index.
// When the channel has its own entry, a pointer to that slice element
// is returned. When exactly one interval is configured, it is used for
// every channel. Otherwise the default [0, 1] interval is returned.
func (p PhotoInput) GetInterval(channel int) *Interval {
	switch n := len(p.Intervals); {
	case channel < n:
		return &p.Intervals[channel]
	case n == 1:
		return &p.Intervals[0]
	default:
		return StandardInterval()
	}
}

// Merge fills every field of p that still holds its zero value with the
// corresponding value from other. Fields that are already set are kept.
// Slices are assigned, not copied, so they are shared with other afterwards.
func (p *PhotoInput) Merge(other *PhotoInput) {
	// Scalar fields: the zero value means "not configured".
	if len(p.Name) == 0 {
		p.Name = other.Name
	}

	if p.OutputIndex == 0 {
		p.OutputIndex = other.OutputIndex
	}

	if p.Height == 0 {
		p.Height = other.Height
	}

	if p.Width == 0 {
		p.Width = other.Width
	}

	if p.ResizeOperation == UndefinedResizeOperation {
		p.ResizeOperation = other.ResizeOperation
	}

	if p.ColorChannelOrder == UndefinedOrder {
		p.ColorChannelOrder = other.ColorChannelOrder
	}

	// Slice fields: only a nil slice counts as "not configured"; an empty
	// but non-nil slice is kept.
	if p.Intervals == nil {
		p.Intervals = other.Intervals
	}

	if p.Shape == nil {
		p.Shape = other.Shape
	}
}

// ModelOutput describes the output that is expected from the model.
// OutputIndex is serialized under "Index"/"index", NumOutputs under
// "Outputs"/"outputs", and OutputsLogits under "Logits"/"logits".
type ModelOutput struct {
	Name          string `yaml:"Name,omitempty" json:"name,omitempty"`
	OutputIndex   int    `yaml:"Index,omitempty" json:"index,omitempty"`
	NumOutputs    int64  `yaml:"Outputs,omitempty" json:"outputs,omitempty"`
	OutputsLogits bool   `yaml:"Logits,omitempty" json:"logits,omitempty"`
}

// Merge fills every field of m that still holds its zero value with the
// corresponding value from other. Fields that are already set are kept.
func (m *ModelOutput) Merge(other *ModelOutput) {
	if len(m.Name) == 0 {
		m.Name = other.Name
	}

	if m.NumOutputs == 0 {
		m.NumOutputs = other.NumOutputs
	}

	if m.OutputIndex == 0 {
		m.OutputIndex = other.OutputIndex
	}

	// A flag that is already true stays true; otherwise other decides.
	m.OutputsLogits = m.OutputsLogits || other.OutputsLogits
}

// ModelInfo represents meta information for the model.
//
// TFVersion is excluded from both YAML and JSON (tag "-"). Tags, Input and
// Output are always serialized, since their tags do not use omitempty.
type ModelInfo struct {
	TFVersion string       `yaml:"-" json:"-"`
	Tags      []string     `yaml:"Tags" json:"tags"`
	Input     *PhotoInput  `yaml:"Input" json:"input"`
	Output    *ModelOutput `yaml:"Output" json:"output"`
}

// Merge completes m with the information from other. A field that
// already has a value in m keeps it. A missing Input or Output is taken
// over from other as the same pointer, and an existing one is merged
// field by field.
func (m *ModelInfo) Merge(other *ModelInfo) {
	if len(m.TFVersion) == 0 {
		m.TFVersion = other.TFVersion
	}

	if len(m.Tags) == 0 {
		m.Tags = other.Tags
	}

	switch {
	case m.Input == nil:
		m.Input = other.Input
	case other.Input != nil:
		m.Input.Merge(other.Input)
	}

	switch {
	case m.Output == nil:
		m.Output = other.Output
	case other.Output != nil:
		m.Output.Merge(other.Output)
	}
}

// IsComplete reports whether the model has both an input and an output
// description and the input has a non-nil Shape.
func (m ModelInfo) IsComplete() bool {
	if m.Input == nil || m.Output == nil {
		return false
	}

	return m.Input.Shape != nil
}

// GetModelTagsInfo reads a SavedModel and returns its available meta graph tags.
//
// It reads "saved_model.pb" from the savedModelPath directory, decodes it as
// a SavedModel protobuf message, and returns one ModelInfo per meta graph
// with TFVersion and Tags filled in; Input and Output are left nil. The
// result is an empty, non-nil slice when the model has no meta graphs.
// An error is returned when the file cannot be read or decoded.
func GetModelTagsInfo(savedModelPath string) ([]ModelInfo, error) {
	savedModel := filepath.Join(savedModelPath, "saved_model.pb")

	data, err := os.ReadFile(savedModel) //nolint:gosec // savedModel path derived from trusted model directory
	if err != nil {
		return nil, fmt.Errorf("vision: failed to read %s (%s)", clean.Path(savedModel), clean.Error(err))
	}

	model := new(pb.SavedModel)

	if err = proto.Unmarshal(data, model); err != nil {
		return nil, fmt.Errorf("vision: failed to unmarshal %s (%s)", clean.Path(savedModel), clean.Error(err))
	}

	metas := model.GetMetaGraphs()

	// The number of results is known up front, so allocate once instead of
	// growing the slice on append.
	models := make([]ModelInfo, 0, len(metas))

	for i := range metas {
		def := metas[i].GetMetaInfoDef()
		models = append(models, ModelInfo{
			TFVersion: def.GetTensorflowVersion(),
			Tags:      def.GetTags(),
		})
	}

	return models, nil
}
